//
// Created by Göksu Güvendiren on 2019-05-14.
//

#include "Scene.hpp"


void Scene::buildBVH() {
    printf(" - Generating BVH...\n\n");
    this->bvh = new BVHAccel(objects, 1, BVHAccel::SplitMethod::NAIVE);
}

Intersection Scene::intersect(const Ray &ray) const
{
    // Every scene-level ray query is answered by the acceleration structure
    // created in buildBVH().
    // NOTE(review): `bvh` is dereferenced unchecked, so buildBVH() must have
    // been called before the first query.
    auto *accelerator = bvh;
    return accelerator->Intersect(ray);
}

// Picks a point on the scene's emissive surfaces for next-event estimation.
//
// An emissive object is chosen with probability proportional to its area, and
// a point on it is then drawn via Object::Sample().
//
// pos: receives the sampled point as filled in by Object::Sample().
// pdf: receives the density of that sample over *all* emissive surfaces.
//      It is set to 0 when the scene has no emitter with positive area.
//      Callers must check for 0 before dividing by it.
void Scene::sampleLight(Intersection &pos, float &pdf) const
{
    // First pass: total area of every emissive object.
    float emit_area_sum = 0;
    for (const auto &object : objects) {
        if (object->hasEmit()) {
            emit_area_sum += object->getArea();
        }
    }

    // Nothing to sample: report a zero density instead of leaving pdf stale.
    if (emit_area_sum <= 0.0f) {
        pdf = 0.0f;
        return;
    }

    // Second pass: walk the emitters' cumulative area until it reaches p.
    // This selects emitter k with probability area_k / emit_area_sum.
    const float p = get_random_float() * emit_area_sum;
    float running_area = 0;
    for (const auto &object : objects) {
        if (!object->hasEmit()) {
            continue;
        }
        const float area = object->getArea();
        running_area += area;
        // Skip zero-area emitters: they cannot be sampled meaningfully.
        // The running sum is unchanged by them, so a positive-area emitter
        // is always reached.
        if (area > 0.0f && p <= running_area) {
            object->Sample(pos, pdf);
            // Sample() is assumed to report a density over this object alone
            // (1 / area_k) — confirm in Object::Sample implementations.
            // Folding in the selection probability gives the density over all
            // emitters. Without it, scenes with several lights are biased dark.
            pdf *= area / emit_area_sum;
            return;
        }
    }
}

// Brute-force closest-hit search over `objects`, without the BVH.
//
// tNear:     on entry, only hits strictly closer than this value are accepted.
//            On a hit it is updated to the closest hit distance.
// index:     on a hit, receives the index reported by the hit object's
//            intersect() (meaning defined by that object type).
// hitObject: always written; set to the closest object hit, or nullptr.
// Returns true iff some object was hit closer than the incoming tNear.
bool Scene::trace(
        const Ray &ray,
        const std::vector<Object*> &objects,
        float &tNear, uint32_t &index, Object **hitObject)
{
    *hitObject = nullptr;
    for (Object *object : objects) {
        float tNearK = kInfinity;
        uint32_t indexK = 0;
        if (object->intersect(ray, tNearK, indexK) && tNearK < tNear) {
            *hitObject = object;
            tNear = tNearK;
            index = indexK;
        }
    }
    return *hitObject != nullptr;
}

// Implementation of Path Tracing
// Estimates the radiance arriving along `ray` with unidirectional path tracing.
//
// At each diffuse hit the estimate is the sum of two terms:
//  - direct lighting, from one explicit light sample (next-event estimation)
//    with a shadow ray;
//  - indirect lighting, from one BSDF-sampled bounce, continued with
//    probability RussianRoulette and reweighted by 1 / RussianRoulette.
// Emitters contribute their emission only when hit by the primary ray
// (depth == 0). Deeper emitter hits are already covered by the direct term,
// so counting them again would double-count the light.
Vector3f Scene::castRay(const Ray &ray, int depth) const
{
    // Clamp a cosine term to the upper hemisphere.
    auto positivePart = [](float x) { return x > 0.0f ? x : 0.0f; };

    // Origin for a secondary ray leaving surface point p (normal n) in
    // direction d. It is nudged off the surface toward d's side so the ray
    // does not re-hit the surface it starts on (self-intersection acne).
    // The offset is scene-scale dependent; tune if geometry is much smaller.
    auto offsetOrigin = [](Vector3f p, Vector3f n, Vector3f d) {
        const float kRayOffset = 1e-3f;
        return dotProduct(d, n) < 0.0f ? p - n * kRayOffset : p + n * kRayOffset;
    };

    Intersection inter = intersect(ray);
    if (!inter.happened) {
        return Vector3f(0, 0, 0);
    }

    if (inter.m->hasEmission()) {
        if (depth == 0) {
            return inter.m->getEmission();
        }
        return Vector3f(0, 0, 0);
    }

    auto &hitPos = inter.coords;
    auto &hitNormal = inter.normal;

    // ---- Direct lighting: one sample on the emitters -----------------------
    Vector3f L_dir(0, 0, 0);
    Intersection lightSample;
    float lightPdf = 0.0f;
    sampleLight(lightSample, lightPdf);
    // lightPdf stays 0 when there is nothing to sample; dividing by it would
    // turn the pixel into NaN/inf.
    if (lightPdf > 0.0f) {
        Vector3f toLight = lightSample.coords - hitPos;
        float dist2 = dotProduct(toLight, toLight);  // squared distance
        Vector3f wi = toLight.normalized();

        Ray shadowRay(offsetOrigin(hitPos, hitNormal, wi), wi);
        Intersection blocker = intersect(shadowRay);
        // The light is visible iff the shadow ray actually hits something and
        // that first hit is (within tolerance) the sampled light point.
        if (blocker.happened && (blocker.coords - lightSample.coords).norm() < 1e-2f) {
            Vector3f f_r = inter.m->eval(ray.direction, wi, hitNormal);
            float cosSurface = positivePart(dotProduct(wi, hitNormal));
            float cosLight = positivePart(dotProduct(-wi, lightSample.normal));
            L_dir = lightSample.emit * f_r * cosSurface * cosLight / dist2 / lightPdf;
        }
    }

    // ---- Indirect lighting: one BSDF-sampled bounce, Russian roulette ------
    Vector3f L_indir(0, 0, 0);
    if (get_random_float() < RussianRoulette) {
        Vector3f wo = inter.m->sample(ray.direction, hitNormal).normalized();

        Ray bounceRay(offsetOrigin(hitPos, hitNormal, wo), wo);
        Intersection next = intersect(bounceRay);
        // Emitters reached by the bounce are skipped: their light is already
        // accounted for by the direct term above.
        if (next.happened && !next.m->hasEmission()) {
            float pdf = inter.m->pdf(ray.direction, wo, hitNormal);
            if (pdf > 0.0f) {
                Vector3f f_r = inter.m->eval(ray.direction, wo, hitNormal);
                float cosSurface = positivePart(dotProduct(wo, hitNormal));
                L_indir = castRay(bounceRay, depth + 1) * f_r * cosSurface / pdf / RussianRoulette;
            }
        }
    }

    return L_dir + L_indir;
}